{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src='http://www-scf.usc.edu/~ghasemig/images/sharif.png' alt=\"SUT logo\" width=200 height=200 align=left class=\"saturate\" >\n",
    "\n",
    "<br>\n",
    "<font face=\"Times New Roman\">\n",
    "<div dir=ltr align=center>\n",
    "<font color=0F5298 size=7>\n",
    "    Introduction to Machine Learning <br>\n",
    "<font color=2565AE size=5>\n",
    "    Computer Engineering Department <br>\n",
    "    Fall 2022<br>\n",
    "<font color=3C99D size=5>\n",
    "    Homework 3: Practical - PyTorch Classification <br>\n",
    "<font color=696880 size=4>\n",
    "    Javad Hezareh \n",
    "    \n",
    "    \n",
    "____\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full Name : \n",
    "### Student Number : \n",
    "___"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem\n",
    "This assignment uses PyTorch to build and implement a MLP model for solving a classification problem. Our goal is to classify galaxy images into 4 classes: ellipticals, lenticulars, spirals, and irregulars. We will use [EFIGI](https://www.astromatic.net/projects/efigi/) dataset which contains 4458 images.\n",
    "\n",
    "* It is highly recommended to run this notebook on Google Colab so that you can utilize its GPU.\n",
    "* If you need to change the inputs of functions you are implementing, or want to add new cells or functions, feel free to do so."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W0bGV8AXO3Yx"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "CdrZr6HWO2p8"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import Dataset, DataLoader, random_split\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('ggplot')\n",
    "###########################################################\n",
    "##  If you need any other packages, import them below    ##\n",
    "###########################################################\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gzydGC2SNuAu"
   },
   "source": [
    "# Prepare and Visualize Data (10 Points)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ANdYnd04N0ze"
   },
   "source": [
    "Run the following cell to download dataset. `prepare_data` will return a pandas dataframe which contains three columns. `name` is the name of image that you can find that in `./efigi-1.6/png`, `class_name` is the type of galaxy in that image and `class_label` is a numerical label for this class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 242
    },
    "id": "DKTkjoAxDfoL",
    "outputId": "1d1583ef-b633-4b10-a7ba-75f9e8bc1051"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "efigi_png files already exists\n",
      "efigi_tables files already exists\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-c86fe277-9454-418d-8b4f-a89ea5de4f63\">\n",
       "    <div class=\"colab-df-container\">\n",
       "      <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>class_name</th>\n",
       "      <th>class_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>PGC0000212</td>\n",
       "      <td>Spirals</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>PGC0000218</td>\n",
       "      <td>Spirals</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>PGC0000243</td>\n",
       "      <td>Lenticulars</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>PGC0000255</td>\n",
       "      <td>Spirals</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>PGC0000281</td>\n",
       "      <td>Spirals</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-c86fe277-9454-418d-8b4f-a89ea5de4f63')\"\n",
       "              title=\"Convert this dataframe to an interactive table.\"\n",
       "              style=\"display:none;\">\n",
       "        \n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "       width=\"24px\">\n",
       "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
       "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
       "  </svg>\n",
       "      </button>\n",
       "      \n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      flex-wrap:wrap;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "      <script>\n",
       "        const buttonEl =\n",
       "          document.querySelector('#df-c86fe277-9454-418d-8b4f-a89ea5de4f63 button.colab-df-convert');\n",
       "        buttonEl.style.display =\n",
       "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "        async function convertToInteractive(key) {\n",
       "          const element = document.querySelector('#df-c86fe277-9454-418d-8b4f-a89ea5de4f63');\n",
       "          const dataTable =\n",
       "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                     [key], {});\n",
       "          if (!dataTable) return;\n",
       "\n",
       "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "            + ' to learn more about interactive tables.';\n",
       "          element.innerHTML = '';\n",
       "          dataTable['output_type'] = 'display_data';\n",
       "          await google.colab.output.renderOutput(dataTable, element);\n",
       "          const docLink = document.createElement('div');\n",
       "          docLink.innerHTML = docLinkHtml;\n",
       "          element.appendChild(docLink);\n",
       "        }\n",
       "      </script>\n",
       "    </div>\n",
       "  </div>\n",
       "  "
      ],
      "text/plain": [
       "         name   class_name  class_label\n",
       "0  PGC0000212      Spirals            2\n",
       "1  PGC0000218      Spirals            2\n",
       "2  PGC0000243  Lenticulars            1\n",
       "3  PGC0000255      Spirals            2\n",
       "4  PGC0000281      Spirals            2"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# run this cell to download and prepare data\n",
    "from data_utils import download_data, prepare_data\n",
    "\n",
    "download_data()\n",
    "df = prepare_data('./efigi-1.6/EFIGI_attributes.txt')\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "xsiH1fRiMhBa"
   },
   "outputs": [],
   "source": [
    "##############################################################\n",
    "##            Visualize 4 sample from each class            ##\n",
    "##                        Your Code                         ##\n",
    "##############################################################\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wbUx3uHhQEVY"
   },
   "source": [
    "# Define Dataset (20 Points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "4CEUTGaFQI51"
   },
   "outputs": [],
   "source": [
    "###############################################################\n",
    "##        Write your dataset class for loading images        ##\n",
    "##                        Your Code                          ##\n",
    "###############################################################\n",
    "\n",
    "class GalaxyDataSet(Dataset):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def __len__(self):\n",
    "        pass\n",
    "\n",
    "    def __getitem__(self):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HF_40_3rQ5Uv"
   },
   "source": [
    "# Define Model (20 Points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "jQfGVj2rQ90G"
   },
   "outputs": [],
   "source": [
    "#####################################\n",
    "##        Define your model        ##\n",
    "##            Your Code            ##\n",
    "#####################################\n",
    "\n",
    "class ClassifierModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def forward(self, x):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1jZeKXV9Rbq3"
   },
   "source": [
    "# Train Model (30 Points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "koMk0TkKRazn"
   },
   "outputs": [],
   "source": [
    "######################################################################\n",
    "##        Instantiate model, define hyper parameters, optimizer,    ##\n",
    "##        loss function ant etc                                     ##\n",
    "######################################################################\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "_YSWvcEGSVor"
   },
   "outputs": [],
   "source": [
    "##############################################################\n",
    "##          Plot metrics graph for different epochs         ##\n",
    "##                        Your Code                         ##\n",
    "##############################################################\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LzhCrIFkSE1D"
   },
   "source": [
    "# Test Model (20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ihixveNISFwE"
   },
   "outputs": [],
   "source": [
    "##################################################\n",
    "##          Test your model on test-set         ##\n",
    "##          and plot confusion matrix           ##\n",
    "##################################################\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "81794d4967e6c3204c66dcd87b604927b115b27c00565d3d43f05ba2f3a2cb0d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
